import anndata as ad
import hdf5plugin
import scipy
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import numpy as np
import pandas as pd
import math
import logging
import scanpy as sc
import wandb
import pickle
import argparse
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau
from tqdm import tqdm
from copy import deepcopy
from dance.utils import set_seed
from scipy.sparse import csr_matrix
from CellBert.utils.eval import downstream_eval
from CellBert.utils.data import XDict, clean_batches
from CellBert.utils.mask import InputDropoutMaskBuilder
from CellBert.model import OmicsFormer
# wandb.login()

class CosineAnnealingWarmupRestarts(_LRScheduler):
    """Cosine-annealing learning-rate schedule with linear warmup and restarts.

    Within each cycle the learning rate rises linearly from ``min_lr`` to the
    cycle's ``max_lr`` over ``warmup_steps`` steps, then follows a half cosine
    back down to ``min_lr``. When a cycle ends a new one starts; its length is
    scaled by ``cycle_mult`` and its peak learning rate by ``gamma``.

    Args:
        optimizer: Wrapped optimizer whose param-group ``lr`` is overwritten.
        first_cycle_steps: Number of steps in the first cycle.
        cycle_mult: Multiplier applied to the cycle length at each restart.
        max_lr: Peak learning rate of the first cycle.
        min_lr: Lowest learning rate (start of warmup / end of annealing).
        warmup_steps: Number of linear warmup steps per cycle; must be
            smaller than ``first_cycle_steps``.
        gamma: Multiplier applied to the peak learning rate at each restart.
        last_epoch: Index of the last step, ``-1`` for a fresh schedule.
    """

    def __init__(self, optimizer, first_cycle_steps, cycle_mult=1., max_lr=0.1, min_lr=0.001,
                 warmup_steps=0, gamma=1., last_epoch=-1):
        assert warmup_steps < first_cycle_steps, \
            "warmup_steps must be smaller than first_cycle_steps"

        self.first_cycle_steps = first_cycle_steps  # first cycle step size
        self.cycle_mult = cycle_mult  # cycle steps magnification
        self.base_max_lr = max_lr  # first max learning rate
        self.max_lr = max_lr  # max learning rate in the current cycle
        self.min_lr = min_lr  # min learning rate
        self.warmup_steps = warmup_steps  # warmup step size
        self.gamma = gamma  # decrease rate of max learning rate by cycle

        self.cur_cycle_steps = first_cycle_steps  # first cycle step size
        self.cycle = 0  # cycle count
        self.step_in_cycle = last_epoch  # step size of the current cycle

        # The base constructor performs an initial ``step()``, so every
        # attribute read by ``step``/``get_lr`` must already exist here.
        super().__init__(optimizer, last_epoch)
        # set learning rate min_lr
        self.init_lr()

    def init_lr(self):
        """Reset every param group (and ``base_lrs``) to ``min_lr``."""
        self.base_lrs = []
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.min_lr
            self.base_lrs.append(self.min_lr)
        # Keep the inherited ``get_last_lr()`` in sync with the optimizer.
        self._last_lr = list(self.base_lrs)

    def get_lr(self):
        """Return the learning rate of each param group for the current step."""
        if self.step_in_cycle == -1:
            return self.base_lrs
        elif self.step_in_cycle < self.warmup_steps:
            # Linear warmup from base_lr up to max_lr.
            return [(self.max_lr - base_lr)*self.step_in_cycle / self.warmup_steps + base_lr for base_lr in self.base_lrs]
        else:
            # Half-cosine decay from max_lr back to base_lr over the rest of the cycle.
            return [base_lr + (self.max_lr - base_lr) \
                    * (1 + math.cos(math.pi * (self.step_in_cycle-self.warmup_steps) \
                                    / (self.cur_cycle_steps - self.warmup_steps))) / 2
                    for base_lr in self.base_lrs]

    def step(self, epoch=None):
        """Advance the schedule by one step, or jump to ``epoch`` if given.

        Updates the cycle bookkeeping, the current cycle's peak learning rate
        and the ``lr`` of every optimizer param group.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
            self.step_in_cycle = self.step_in_cycle + 1
            if self.step_in_cycle >= self.cur_cycle_steps:
                # Restart: begin a new cycle whose post-warmup part is scaled by cycle_mult.
                self.cycle += 1
                self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps
                self.cur_cycle_steps = int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps
        else:
            if epoch >= self.first_cycle_steps:
                if self.cycle_mult == 1.:
                    self.step_in_cycle = epoch % self.first_cycle_steps
                    self.cycle = epoch // self.first_cycle_steps
                else:
                    # Closed-form cycle index from the geometric series of cycle lengths.
                    n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult))
                    self.cycle = n
                    self.step_in_cycle = epoch - int(self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1))
                    self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** (n)
            else:
                self.cur_cycle_steps = self.first_cycle_steps
                self.step_in_cycle = epoch

        self.max_lr = self.base_max_lr * (self.gamma**self.cycle)
        self.last_epoch = math.floor(epoch)
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr
        # This override bypasses the base ``step()``, which is what normally
        # records ``_last_lr``; record it here so ``get_last_lr()`` works.
        self._last_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]

def main(task=None, config=None):
    """Fine-tune the pretrained OmicsFormer, then evaluate the best checkpoint.

    Relies on module-level state prepared by the ``__main__`` section of this
    script (``args``, ``pretrained_gene_list``, ``gene_list``, ``batch_labels``
    and the per-batch ``*_list`` / ``*_batch_idx`` containers).

    Args:
        task: Task name passed to ``downstream_eval``; when ``None`` it is
            read from ``config['head_type']``.
        config: Model/training configuration. When ``None`` the function runs
            as a wandb sweep trial and takes its configuration from
            ``wandb.config``.
    """
    global gene_list, batch_labels, seq_list, order_list, coord_list, label_list
    tune_flag = config is None

    group = f"den_{args.pre_model}_{args.dataset}"
    wandb.init(group=group)
    if tune_flag:
        config = wandb.config
    else:
        wandb.config = config
    if task is None:
        task = config['head_type']

    config["gene_list"] = pretrained_gene_list
    config["batch_num"] = batch_labels.max() + 1
    device = torch.device('cuda')

    model = OmicsFormer(**config)
    pretrained_file = f'{args.pre_model}.pt'
    pretrained_model_dict = torch.load(pretrained_file)
    # Remove the leading "module." from checkpoint keys, but only where it is
    # actually present so that un-prefixed keys are not truncated.
    prefix = 'module.'
    pretrained_model_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in pretrained_model_dict.items()
    }
    try:
        model.load_state_dict(pretrained_model_dict)
    except RuntimeError:
        # Key or shape mismatch: only load params that are in the model and
        # match the size.
        model_dict = model.state_dict()
        pretrained_dict = {
            k: v
            for k, v in pretrained_model_dict.items()
            if k in model_dict and v.shape == model_dict[k].shape
        }
        for k, v in pretrained_dict.items():
            print(f"Loading params {k} with shape {v.shape}")
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
    model.latent.layers[1].mean = pretrained_model_dict['latent.layers.1.mean'].requires_grad_(True)
    model.latent.layers[1].std = pretrained_model_dict['latent.layers.1.std'].requires_grad_(True)
    print(model)
    model.to(device)
    optim = torch.optim.AdamW(model.parameters(), lr=config['lr'], weight_decay=config['wd'])
    scheduler = CosineAnnealingWarmupRestarts(
        optim,
        first_cycle_steps=15,
        cycle_mult=2,
        max_lr=config['lr'],
        min_lr=1e-7,
        warmup_steps=5,
        gamma=0.9
    )

    def build_inputs(i, x_seq, idx):
        """Assemble the model input for mini-batch ``idx`` of data batch ``i``."""
        return XDict({
            'x_seq': x_seq.to(device),
            'batch': batch_list[i][idx].to(device),
            'coord': coord_list[i][idx].to(device),
            'label': label_list[i][idx].to(device),
        })

    def predict(i):
        """Run the model over all validation mini-batches of data batch ``i``.

        Returns ``(pred, label, losses)`` where ``pred`` and ``label`` are the
        concatenated outputs indexed by the concatenated mini-batch indices
        and ``losses`` holds one float per mini-batch.
        """
        preds, labels, orders, losses = [], [], [], []
        for x_seq, idx in zip(valid_list[i], valid_batch_idx[i]):
            x_dict = build_inputs(i, x_seq, idx)
            out_dict, loss = model(x_dict, gene_list)
            preds.append(out_dict['pred'])
            labels.append(x_dict['label'])
            orders.append(idx)
            losses.append(loss.item())
        order = torch.cat(orders)
        return torch.cat(preds)[order], torch.cat(labels)[order], losses

    train_loss = []
    valid_loss = []
    valid_metric = []
    # State dict of the epoch with the lowest validation RMSE; stays None if
    # no epoch ever qualifies (e.g. zero epochs or a NaN metric).
    best_state = None
    for epoch in tqdm(range(config['epochs'])):
        model.train()
        batch_loss = []
        for i in range(len(seq_list)):
            minibatch_loss = []
            for x_seq, idx in zip(train_list[i], train_batch_idx[i]):
                x_dict = build_inputs(i, x_seq, idx)
                out_dict, loss = model(x_dict, gene_list)
                optim.zero_grad()
                loss.backward()
                optim.step()
                minibatch_loss.append(loss.item())
            batch_loss.append(sum(minibatch_loss) / len(minibatch_loss))
        train_loss.append(sum(batch_loss) / len(batch_loss))
        scheduler.step()

        model.eval()
        with torch.no_grad():
            batch_loss = []
            # Collected across ALL data batches so the epoch metric is their
            # mean rather than only the last batch's value.
            valid_epoch = []
            for i in range(len(seq_list)):
                pred, label, minibatch_loss = predict(i)
                batch_loss.append(sum(minibatch_loss) / len(minibatch_loss))
                valid_scores = downstream_eval(task, pred, label, eval_mask=valid_mask_list[i])
                valid_epoch.append(valid_scores['rmse'])
                test_scores = downstream_eval(task, pred, label, eval_mask=test_mask_list[i])

        valid_loss.append(sum(batch_loss) / len(batch_loss))
        valid_metric.append(sum(valid_epoch) / len(valid_epoch))

        if task == 'denoising':
            # The scores printed/logged here are those of the last data batch.
            print(f'Epoch {epoch} | Train loss: {train_loss[-1]:.4f} | Valid loss: {valid_loss[-1]:.4f}')
            print(f'Valid RMSE: {valid_scores["rmse"]:.4f} | Test RMSE: {test_scores["rmse"]:.4f}')
            print(f'Valid MAE: {valid_scores["mae"]:.4f} | Test MAE: {test_scores["mae"]:.4f}')

            wandb.log({
                "train": train_loss[-1],
                "valid": valid_loss[-1],
                "valid_rmse": valid_scores["rmse"],
                "test_rmse": test_scores["rmse"],
                "valid_mae": valid_scores["mae"],
                "test_mae": test_scores["mae"],
            })

        # Snapshot the weights whenever this epoch is the best so far.
        if min(valid_metric) == valid_metric[-1]:
            best_state = deepcopy(model.state_dict())

    # Inference with the best checkpoint (current weights if none was saved).
    if best_state is not None:
        model.load_state_dict(best_state)
    final_pred = []
    final_label = []
    model.eval()
    with torch.no_grad():
        for i in range(len(seq_list)):
            pred, label, _ = predict(i)
            final_pred.append(pred.cpu())
            final_label.append(label.cpu())
    del model
    torch.cuda.empty_cache()

    if task == 'denoising':
        pred = torch.cat(final_pred)
        label = torch.cat(final_label)
        # NOTE(review): only the first batch's test mask is used against the
        # concatenated predictions — confirm this is intended when there is
        # more than one data batch.
        scores = downstream_eval(task, pred, label, eval_mask=test_mask_list[0])
        print(scores)
        if tune_flag:
            wandb.log({
                'final_rmse': scores['rmse'],
                'final_mae': scores['mae'],
                'final_corr': scores['corr'],
            })
            wandb.finish()
    # Drop the remaining large tensors and release cached GPU memory.
    del pred, label, best_state
    torch.cuda.empty_cache()

def create_sparse_tensor(x):
    """Build a float32 sparse COO tensor from a CSR-format matrix ``x``.

    ``x`` must expose ``indptr``, ``indices``, ``data`` and a 2-D ``shape``
    (as SciPy CSR matrices do).
    """
    n_rows, n_cols = x.shape[0], x.shape[1]
    csr = torch.sparse_csr_tensor(x.indptr, x.indices, x.data, (n_rows, n_cols))
    # Convert CSR -> COO first, then cast the values to float32.
    coo = csr.to_sparse()
    return coo.float()

if __name__ == '__main__':
    # Script entry point: parse CLI options, load and preprocess the chosen
    # dataset, build the masked train/validation mini-batches that ``main``
    # reads as module-level globals, then either launch a wandb sweep
    # (``--tune``) or run a single fine-tuning job with fixed settings.
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default='denoising')
    parser.add_argument("--dataset", type=str, default='5k_pbmc')
    parser.add_argument("--latent_mod", type=str, default='gmvae')
    parser.add_argument("--pre_model", type=str, default='20230510_10M_12M')
    parser.add_argument("--batch_size", type=int, default=1000)
    parser.add_argument("--patience", type=int, default=10)
    parser.add_argument("--seed", type=int, default=10)
    parser.add_argument("--tune", action='store_true')
    args = parser.parse_args()
    set_seed(args.seed)
    batch_size = args.batch_size
    patience = args.patience
    torch.set_num_threads(32)

    # Data Setup
    task = args.task
    # dataset key -> (file name under ./data, reader function)
    dataset_readers = {
        '5k_pbmc': ('5k_pbmc_protein_v3_filtered_feature_bc_matrix.h5', sc.read_10x_h5),
        'jurkat': ('jurkat.h5ad', ad.read_h5ad),
        '293t': ('293t.h5ad', ad.read_h5ad),
    }
    if args.dataset not in dataset_readers:
        # Fail early with a clear message instead of a NameError further down.
        parser.error(f"unknown --dataset {args.dataset!r}; expected one of {sorted(dataset_readers)}")
    dataset_name, read_dataset = dataset_readers[args.dataset]
    data = read_dataset(f'./data/{dataset_name}')
    data.obs['batch'] = 0
    data.var_names_make_unique()
    gene_list = data.var.index.to_list()

    # NOTE: pickle.load can execute arbitrary code; only point --pre_model at
    # config files from a trusted source.
    with open(f"{args.pre_model}.config.pkl", "rb") as openfile:
        config = pickle.load(openfile)
    pretrained_gene_list = config['gene_list']
    # Set membership keeps the gene filter linear instead of quadratic.
    pretrained_gene_set = set(pretrained_gene_list)
    gene_list = [x for x in gene_list if x in pretrained_gene_set]
    data = data[:, gene_list]
    sc.pp.filter_genes(data, min_cells=data.shape[0] * 0.05)
    sc.pp.filter_cells(data, min_counts=1)
    # Keep the filtered counts in .raw before normalisation; they are used
    # below as the denoising labels.
    data.raw = data
    sc.pp.normalize_total(data, target_sum=1e4)
    sc.pp.log1p(data)
    print(data.shape)
    gene_list = [x for x in gene_list if x in data.var.index]
    print(len(gene_list))

    mask_builder = InputDropoutMaskBuilder(input_drop_type="mar", valid_drop_rate=0.1,
                                           test_drop_rate=0.1, seed=args.seed)
    order = np.arange(data.shape[0])
    batch_labels = LabelEncoder().fit_transform(data.obs['batch'])
    train_batch_idx = []
    valid_batch_idx = []
    train_list = []
    valid_list = []
    valid_mask_list = []
    test_mask_list = []
    seq_list = []
    batch_list = []
    order_list = []
    coord_list = []
    label_list = []

    for batch in tqdm(range(batch_labels.max() + 1)):
        in_batch = batch_labels == batch
        x_raw = data[in_batch].raw.X.astype(float)
        x = data[in_batch].X.astype(float)
        train_mask, valid_mask, test_mask = mask_builder.apply_mask(create_sparse_tensor(x_raw))
        # The masked input is identical for every mini-batch, so build it once
        # per data batch instead of once per mini-batch.
        x_masked = csr_matrix(x.toarray() * train_mask)

        train_batch = []
        train_minibatch_list = []
        train_loader = DataLoader(range(x.shape[0]), batch_size=batch_size, shuffle=True)
        for minibatch in train_loader:
            train_batch.append(create_sparse_tensor(x_masked[minibatch]))
            train_minibatch_list.append(minibatch)
        train_list.append(train_batch)
        train_batch_idx.append(train_minibatch_list)

        valid_batch = []
        valid_minibatch_list = []
        valid_loader = DataLoader(range(x.shape[0]), batch_size=batch_size, shuffle=False)
        for minibatch in valid_loader:
            valid_batch.append(create_sparse_tensor(x_masked[minibatch]))
            valid_minibatch_list.append(minibatch)
        valid_list.append(valid_batch)
        valid_batch_idx.append(valid_minibatch_list)

        valid_mask_list.append(valid_mask)
        test_mask_list.append(test_mask)

        seq_list.append(create_sparse_tensor(x_masked))
        order_list.append(order[in_batch])
        batch_list.append(torch.from_numpy(batch_labels[in_batch]))
        coord_list.append(torch.zeros(x.shape[0], 2)-1)
        # coord_list.append(torch.zeros(x.shape[0], 2))
        # label_list.append(create_sparse_tensor(x_raw))
        # ``.toarray()`` replaces the deprecated sparse-matrix ``.A`` shorthand.
        label_list.append(torch.from_numpy(x_raw.toarray()))
        del x_masked

    out_dim = len(gene_list)
    del data, order, x

    if args.tune:
        param_dict = {
            "head_type": {'values': [task]},
            "mask_type": {'values': ['hidden']},
            "dec_mod": {'values': ['mlp']},
            "dec_hid": {'values': [128, 64, 256]},
            "dec_layers": {'values': [2, 3, 4]},
            "model_dropout": {'values': [0, 0.1, 0.3, 0.5, 0.7]},
            "architecture": {'values': ["OmicsFormer"]},
            "epochs": {'values': [3000]},
            "norm": {'values': ["layernorm"]},
            "wd": {'values': [0, 1e-8]},
            "w_li": {'values': [0]},
            "w_en": {'values': [0]},
            "w_ce": {'values': [0]},
            "out_dim": {'values': [out_dim]},
        }

        # ['20230510_50M_12M', '20230506_12M' (36M), '20230510_10M_12M']
        # pre_model -> (enc_hid, enc_layers, num_clusters, lr candidates);
        # the remaining encoder/latent settings are shared by all presets.
        encoder_presets = {
            '20230510_50M_12M': (512, 12, 16, [1e-6]),
            '20230506_12M': (512, 8, 2, [1e-5]),
            '20230510_10M_12M': (256, 4, 8, [1e-4]),
            '20230513_20M_12M': (384, 4, 16, [1e-4, 1e-5]),
            '20230515_20M_12M': (384, 4, 16, [1e-4]),
            '20230515_10M_12M': (256, 4, 16, [1e-4]),
        }
        if args.pre_model in encoder_presets:
            enc_hid, enc_layers, num_clusters, lr_values = encoder_presets[args.pre_model]
            param_dict["enc_mod"] = {'values': ['performer']}
            param_dict["latent_mod"] = {'values': ['gmvae']}
            param_dict["enc_hid"] = {'values': [enc_hid]}
            param_dict["enc_layers"] = {'values': [enc_layers]}
            param_dict["post_latent_dim"] = {'values': [64]}
            param_dict["gumbel_softmax"] = {'values': [False]}
            param_dict["num_clusters"] = {'values': [num_clusters]}
            param_dict["lr"] = {'values': lr_values}

        param_dict["mask_node_rate"] = {'values': [0.7, 0.1, 0.3, 0.5, 0]}
        param_dict["mask_feature_rate"] = {'values': [0.7, 0.1, 0.3, 0.5, 0]}
        param_dict["drop_node_rate"] = {'values': [0]}
        sweep_configuration = {
            'method': 'bayes',
            'name': 'tuning-ann',
            # Optimise a metric that ``main`` actually logs: the final test
            # RMSE, where lower is better.
            'metric': {
                'goal': 'minimize',
                'name': 'final_rmse'
            },
            'parameters': param_dict,
        }
        sweep_id = wandb.sweep(sweep=sweep_configuration, project='CellBert')
        print(sweep_id)
        wandb.agent(sweep_id=sweep_id, function=main, count=1000)

        # CUDA_VISIBLE_DEVICES=1 python denoising.py --pre_model 20230515_20M_12M --tune --dataset 5k_pbmc
        # wandb.agent(sweep_id="fm96zsuy", function=main, count=1000) # 20230515_20M_12M

        # CUDA_VISIBLE_DEVICES=1 python denoising.py --pre_model 20230515_20M_12M --tune --dataset jurkat
        # wandb.agent(sweep_id="vhrroz0o", function=main, count=1000) # 20230515_20M_12M

        # CUDA_VISIBLE_DEVICES=1 python denoising.py --pre_model 20230515_20M_12M --tune --dataset 293t
        # wandb.agent(sweep_id="4you9p3j", function=main, count=1000) # 20230515_20M_12M

        # CUDA_VISIBLE_DEVICES=1 python denoising.py --pre_model 20230515_10M_12M --tune --dataset 5k_pbmc
        # wandb.agent(sweep_id="yxd4acqf", function=main, count=1000) # 20230515_10M_12M

        # CUDA_VISIBLE_DEVICES=1 python denoising.py --pre_model 20230515_10M_12M --tune --dataset jurkat
        # wandb.agent(sweep_id="4you9p3j", function=main, count=1000) # 20230515_20M_12M

        # CUDA_VISIBLE_DEVICES=1 python denoising.py --pre_model 20230515_10M_12M --tune --dataset 293t
        # wandb.agent(sweep_id="4you9p3j", function=main, count=1000) # 20230515_20M_12M


        # wandb.agent(sweep_id="u5sdk3u8", function=main, count=1000) # spa_v1, den, rmse
        # wandb.agent(sweep_id="2n28iri0", function=main, count=1000) # spa_v2 (raw), den, rmse
        # wandb.agent(sweep_id="2n4nw1m3", function=main, count=1000) # spa_v3 raw, den, rmse

    else:
        # Single run: override the pretrained config with fixed fine-tuning
        # settings and train once.
        # config['enc_hid'] = 128
        # config['enc_layers'] = 2
        # config['num_clusters'] = 4
        # config['norm'] = 'batchnorm'

        config['mask_type'] = 'hidden'
        # config['mask_type'] = 'input'
        config['dec_mod'] = 'mlp'
        config['dec_hid'] = 256
        config['dec_layers'] = 4
        config['model_dropout'] = 0.5
        config['mask_node_rate'] = 0.7
        config['mask_feature_rate'] = 0.
        config['drop_node_rate'] = 0.
        config['epochs'] = 2000
        config['lr'] = 1e-4
        config['wd'] = 1e-8
        config['w_li'] = 0.
        config['w_en'] = 0.
        config['w_ce'] = 0.
        config['gumbel_softmax'] = True
        config['head_type'] = task
        config['out_dim'] = out_dim
        main(task, config)

